{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 3])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 3 words sentences (=sequence_length is 3)\n",
    "sentences = [\"i love you\", \"he loves me\", \"she likes baseball\", \"i hate you\", \"sorry for that\", \"this is awful\"]\n",
    "labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.\n",
    "\n",
    "#parameters\n",
    "embedding_size = 2\n",
    "n_class = 2\n",
    "n_hidden = 5\n",
    "\n",
    "words = ' '.join(sentences)\n",
    "words = list(set(words.split()))\n",
    "vocab_size = len(words)\n",
    "word_id = {w:i for i,w in enumerate(words)}\n",
    "\n",
    "input_batch = []\n",
    "target_batch = []\n",
    "\n",
    "for sen in sentences:\n",
    "    input_batch.append(np.asarray([word_id[w] for w in sen.split()]))\n",
    "    \n",
    "for out in labels:\n",
    "    target_batch.append(out)\n",
    "\n",
    "input_batch = Variable(torch.LongTensor(input_batch))\n",
    "target_batch = Variable(torch.LongTensor(target_batch))\n",
    "\n",
    "input_batch.size()\n",
    "# X = nn.Embedding(vocab_size,embedding_size)(input_batch)\n",
    "# hidden = Variable(torch.zeros(2,len(X[1]),n_hidden))#hidden_state\n",
    "\n",
    "# cell = Variable(torch.zeros(2,len(X[1]),n_hidden))\n",
    "# outputs,(final_hidden,final_cell) =nn.LSTM(embedding_size,n_hidden,bidirectional=True)(X,(hidden,cell))#\n",
    "# final_hidden.size()#torch.Size([6, 3, 10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class BiLSTM(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(BiLSTM,self).__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size,embedding_size)\n",
    "        self.lstm = nn.LSTM(embedding_size,n_hidden,bidirectional=True)\n",
    "        self.out = nn.Linear(n_hidden*2,n_class)\n",
    "        #outputs [batch_size,seq_len,hidden_size*2] [3,6,10]\n",
    "    def self_attn(self,final_outputs,final_hidden):# #final_hidden[2,batch_size,hidden_size] 2,3,5\n",
    "        final_hidden = final_hidden.view(-1,2*n_hidden,1)#[batch_size,2*hidden_size,1]\n",
    "        attn = torch.bmm(final_outputs,final_hidden).squeeze(2)#[batch_size,seq_len] 6,3\n",
    "        soft_attn_weights = F.softmax(attn,1)#[batch_size,seq_len]\n",
    "        #6 10 3 6,3,1\n",
    "        context = torch.bmm(final_outputs.transpose(1,2),soft_attn_weights.unsqueeze(2)).squeeze(2) # 6,10\n",
    "        return context,soft_attn_weights.data.numpy()\n",
    "    def forward(self,X):\n",
    "        input = self.embedding(X)#torch.Size([6, 3, 2])[batch_size,n_step,embedding_size]\n",
    "        input = input.permute(1,0,2)#[n_step,batch_size,embedding_size]\n",
    "        hidden = Variable(torch.zeros(2,len(X),n_hidden))#hidden_state\n",
    "        cell = Variable(torch.zeros(2,len(X),n_hidden))#cell_state\n",
    "        #outputs [seq_len,batch_size,hidden_size*2]\n",
    "        #final_hidden[2,batch_size,hidden_size]\n",
    "        #final_cell[2,batch_size,hidden_size]\n",
    "        outputs,(final_hidden,final_cell) = self.lstm(input,(hidden,cell))#\n",
    "        outputs = outputs.permute(1,0,2)\n",
    "        attn_output,attention = self.self_attn(outputs,final_hidden)\n",
    "        return self.out(attn_output),attention\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1000,loss:0.004434\n",
      "epoch:2000,loss:0.000778\n",
      "epoch:3000,loss:0.000288\n",
      "epoch:4000,loss:0.000134\n",
      "epoch:5000,loss:0.000069\n"
     ]
    }
   ],
   "source": [
    "model = BiLSTM()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(),lr=0.001)\n",
    "\n",
    "for i in range(5000):\n",
    "    optimizer.zero_grad()\n",
    "    output,attn = model(input_batch)\n",
    "    loss = criterion(output,target_batch)\n",
    "    if (i+1) %1000 ==0:\n",
    "        print('epoch:%d,loss:%f' % (i+1,loss))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sorry hate you is Bad Mean...\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJsAAAEICAYAAABI5IbVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAE61JREFUeJztnXu0XUV9xz9f8oS8SAiYEAghPARB\nAgUFozwUkccSkiIWEArhrdCgwGrRFrAIylpqwFLBNoumYAUjCvJKpV0FIwUEmpBggEB45MUjYCCE\nkIgJya9/zD7cfW9u7t333tlzztn5fdY6656zzz7zm1n5Zmbv2fP7jswMx0nBFvWugLP54GJzkuFi\nc5LhYnOS4WJzkuFic5LhYnOS4WJzkuFic5LRu94VcAKSfgt0+jjHzD6XoDql4GJrHG7Ovf8L4GTg\nJmARsCNwDnB38lpFRP5stPGQtAA41cyeyB37JHCzmX2sfjXrGX7N1pgMAfq0OdYb2KYOdYmGD6ON\nyY+BeyRNB5YC2xOG1Z/UtVY9xIfRBkXSccBEYCTwJnCPmd1R31r1DBebkwwfRjtA0kKKTUeMjRx3\nIvC/ZvZWzHJz5delXS62jpmUe38UcDwwhTAdMRq4BLi/hLjXAWcDD5RQNtSpXT6MFkTSK8ARZjY/\nd+xjwAwz2zlyrG8ARwATzOyDmGW3EytZu7xnK44BuwLzc8d2BvqWEGsFYfpjjqSfAO99WAmzn0aO\nlaxdLrbiXAn8UtIDtExHHAFcXEKsScA6YDnw5dxxA2KLLVm7fBjtApLGARNomY6418xm1bdWPSdV\nu1xsDYyk7QgX7IvN7I/1rk9P8cdVBZF0oaRRiWINkfRrYBnwKLBM0p2Sti4hVrJ2udiK8zVgn0Sx\nbiBcT+9kZn0JF+x9suOxSdYuH0YLIulkYDJwpJmtKjnWW8D+ZrYod2wsMMvMhkWOlaxdfjdanN0J\nd4MvSrqF1tMR34kcawnwOWBa7thns+OxSdYuF1txxgALste22QsKPPbpBl8HZmS9zsvAWOAg4Isl\nxBpDonb5MNqgSNoBOJWwSncpcKuZLa1vrXqGi62LZCtma9MR/1dSjCnAk9nrOUvwj5SkXS62YmTT\nA3cDewCvEyZAnyc8v3wlcqxrCddSuxOGtflk4jOzaR39thux0rXLxVYMSfcBrwEXmtn7krYErgdG\nmdkxkWP1Ijyv3IMwLXEcsBfwgpmNixwrXbtcbMWQ9A6wj5ktyR0bDfzBzKJOtkpaCzwG3AvMJois\nlOu1lO3ySd3izANOa3PsdOCZEmKdCTwEHAh8C7hJ0lRJl5QQK1m7vGcriKS9gf8CVgMLCbP6A4Cj\nzGxe5FgHALsRrtl2Az5OuCudbWZHRI6Vrl0utuJIGgAcS8t0xH1m9l7Hv+pWnDcINwRzCMPok2a2\nMHacXLw07XKxFUPSHbRMRzxpZm/UsS4PAyeY2bIIZSVrl4utIJIupGU6YjfCg/G5hH+gKxLXZQUw\nLn9R34OykrXLH1cV5y5gT8J0xJ7AOML0xPp6VioCydrlPVtBJG0AHiZMR/wnMN/MNtSpLjF7tmTt\ncrEVRNKnaekBdgcGA2uB581scuK6xBRbsnb5PFtxRhIeHW0DDCNc32wL9K9DXWL2EMna5T1bQSQ9\nQctd22xgnpmtrVNdYvZsydrlYouEpCXA+NgPrzcRa0fg1RTXjDHb5Xej8RhEDy5Lsgv1Iv4bvRKv\na+tRu/K42BqHvNXBJOALhATiRYR1ZlcQMq2aFh9GIxH5Omo58In8IypJOwMPm1mStLtc3Gjt8rvR\nxmQFIeElz6HA+3WoSzR8GI1HzCHiImC6pItp8d/YFTglYoyiRGuXiy0eilWQmd0naRfgSFr8N/7b\nzF6NFaMLRGuXX7P1AEl9gN5m9qcSyu4LfIQ2/9iRrglHFzkvRqxWcV1sxZA0DTjfzN7PHTuEsDdB\nbJvTCwhOkHl7egFmZr0ilF+bZhGth8lWn2PEyuM3CMU5nY0N8p4Hyrg7/A7BarR/Nq/Wy8y2iPWP\nXyvLzLYALqNl5Ud/4KPAncB3Y8TK4z1bJ0iqrc+/GTgfWFP7CjgcGGNmh0SOuQg4Om89WhZZwss4\nM1ucOzaa4CuyXcxYfoPQOWdkf41wN1jzuN1AsCw4uYSYk4Gpkr5qZmUk1ORZDEySdJWZbZAkQgLM\nm7EDec9WkOw6Z2szezdBrIWEVRgDCHNuH8Ys4frwQOAOYCDBD2474M/Al83s4ZixvGcrzr8R/hFS\nMClRHMzs8Wya5UBaplkeN7M1Hf+y63jP1k2yzHHL352WEKPmv7Ekv0Nfs+J3owWRdLCkZ7P3E4G3\ngZWSos/qSxolaRbwIHAN8KCk2ZmzUexYt2f7HpSOi604NxKGUoBvA18FzgOuKiHWvxIWMw43s9rK\n2SeBqSXEGknYTLd0fBgtiKTVwC6Eh+Evm9mwrKd5zswGRo6V0ldkP+A/gLPN7LGYZbfFbxCK8zTw\nTaAf8FD2qOoUWu+MEoua/8bVuWNl+YpcSxjhHpE0m9Y2p1H3o/eerSCS9iUMpe8THLZHALcBJ8ae\nIkjsK3L6pr4zs1uixnKxdQ9JKtMRMpX/RkpcbA2MpF5mtj6bB1tsJe/QVzZ+N1oQST/M9iJIEWtf\nSS8BX8oO3QEsyIbXpsXFVpzPEnwwUnAjcDvBDgFgf8JCgBtjFJ5fzyZp9KZeMWK1iuvDaDEkfZ6w\nxuy4/AqJkmKtAnY3s9dzx7YnTLMMjlD+h89526xtyxNl7Vwen/oozikED4znJN1L6ymCMyPHeoqw\n2uR7uWNnEKZEeky2jm2j92XjYivOoux1X4JY5wP3Z9MSi4CdCIYvRyeIXRo+jDYokgYStg+qTX3M\nsJI3MisbF1uDkj2h2IWwd9UoYFkZiTVZrDEEUbdNrnkoZhwfRhsQSQcD0wnZVWOBbwAnSjo29lY/\nkr4HXEpYOJl3L7IsdrxY3rM1HpLmEjay/T6Z9UG2B8LJZnZA5FgrgdPM7O6Y5baHz7N1gKT1kgZn\n7zdkn/OvDZLK8NQdA/ymzbHbCUZ9sXkVKN3mC3wY7YyxuZyDnTs8My6PA58nDGWWrQr+e6CM1bpn\nADdIugr4XZnPX30YbUCyx2L3EHI5lxKSUBYBx5rZSxHKb+sFV7sxsNzn6JO6LrYOkDQeeKkeG2wo\n7Mx3ILADQXDzYvU6knbq5JR+BFuJZ2PEq+HXbB1zFzAEQNLL2dxX6WR3o/PM7FHCHeKDwFux8h3M\nbHHtRVji/kabYyMoYfLae7YOyJaCjzWzN7KhZ0iKiVVJ8wgeIlMkzQF+RBjarighb3Q9MDSfDyvp\nI4SMrn5RY7nYNo2kewj/y58m5HLeRuu5KCD+s9EU+Q71sJXwu9GO+Qph7f9wwsXzUtK4P6bId0hu\nK+E9W0ES2y+kzHdI1i7v2Yrz14QElNIxs7nA+NpnSQtKNG5OZivhPZuTDJ/6cJLhYusGks71WF3H\nxdY9kgmgSrFcbE4yNusbhL69t7L+/bru07LugzX06b1Vl34zfOzKLscBWPX2Bwwa1vVJg+ULu9Ou\n1fTpPaDLv1u1+rXlZrZtZ+dt1lMf/fttzUF7nZck1qRbZySJU2PaGccli/XAI5cXSm30YdRJhovN\nSYaLzUmGi81JhovNSYaLzUmGi81JhovNSUanYpN0WLZLXOlIGiOp2480JA2VNF3SCTHr5cSh1J5N\n0kxJk8qMkYt1BWHXlRNTxHO6TpWG0esJWev12EfdKUBRsUnSjZJWS3pE0q7ZwfGS5khaI+mJ2h5I\nkm7KhsNDgX+XZJJuzhW2T1bOe5Iey3YZyQc7SNJ8Se9KuqZIBc3sHTNbREvihtNgFBXbaEIWzl6E\nDJ9bJQn4JfBrgrXSA8APs/MnA0OBR4ALsvfnA0gaRNhQ4n7CFtEzCdtE57mekNV0GnBpTJduSedK\nmiVp1roPou9y6HRA0VUf64BLzew9Sf9A8PLakWAP8CZBhFsTxENmWvcnSR8Aa8zsnVxZXwTeM7Or\nACR9F5ibpavVuLq25aGkWqyXu9nGVpjZVLINxwYP2H7zXV9VB4qKbXnNZyLLDl9H2M1tAmF151LC\n9s9FjEh2JJikkJW3imB8R+gsgdDb1VjLxk7WThNSdBjdRtJWAJKGA32A/sDXCbvH7Uf72xNuYGOh\nLCUYEpOV11vSU5I+9B5LkcPopKeo2PoC12TuN1cTrNMHE67jBmVuP1PYWFgLgMMljZB0qKR+wAxg\nsKTLJY0i+I4NIfSMToUpKralhA3rnwL2I9gS3E9wuplN2Iz1JmD7zJSkxtWEXmwJcAvBhuldgsX6\n0cDzwBcIG1ls5KHhVItOr9nMbCbhbhTgrDZfn9Tm85Q2v30FOLidMueQy/jOHV/Exo7VYzqrY0/O\nd9LRNDkIkv5Ai+jbMtbM3k5ZH6frNI3YgGPYdH3f2cRxp4FoGrFlQ7LTxFTp2ajT4LjYnGQ0zTBa\nCmvex2Y9nSTU1Mlf6vykiNxz+4+TxRpS0DnOezYnGS42JxkuNicZLjYnGS42JxkuNicZLjYnGS42\nJxkuNicZLjYnGZWxX5C0jaQ7Ja2U9IykwyJXz+khlbFfAK4h5DKMI+Sw/kJS/0SxnQJUaRg9BvhR\ntrT8Xwj7qu9R1xo5raiM/QIh/2FW9n6b7G+KvUGdglTGfsHMrjOz17OPlwCPm9lzbc9rZb+QZudD\nJ6Ny9guSrgUOAj7V3vet7Bc0zO0XElIp+4VMaH8FHGZmC4v8xklHZewXJJ1DyGM9pL3h06k/lbBf\nkDQQ+D5wMfC2pK2zV9/ulunEpyr2C/sTrhl/DqzIvb7SgzKdyFTCfsHMftf2d07j0TTZVW6/0Pw0\njdhw+4Wmp2nE5vYLzU+Vno06DY6LzUmGi81JRtNcs5XB2rFbsvCafZLE2vL3aeeXjz3rbxJG+2ah\ns7xnc5LhYnOS4WJzkuFic5LhYnOS4WJzkuFic5LhYnOS4WJzklEl+4WRkn6T5aLOlXRQ7Po5PaNK\n9gs3ELK2PgrcAfwiUVynIJUYRrOE6UHAD8zsVeBXwGhJA+pbMydPJewXLHCEmb0saQvgFOAFM1td\nsH1OAipjv5DjWeAi4Mz2vszbL6x/17WYksrZLwATgUsJ/yH2NLM1+S/z9gv9dxnl9gsJKdqztbJf\nIIhvJKG3eg2Yln3ulv2CmU03s3W5c2bm3ndqvyBpsKRPZeU9R0hWHg0cXqA+TiKqYr8wAni0VkfC\nkF+L7zQIlbBfAF4iZNdflpV5KSEj/vc9KNOJTCXsF8xsPXACIfv+eeBIYIInLjcWlbBfyM57ur1Y\nTuPQNAkvbr/Q/DSN2HD7haanacTm9gvNTyWejTrNgYvNSYaLzUlG01yzlUG/V9az27dWpgmmLntS\n94gZD9+VLFavkcXO857NSYaLzUmGi81JhovNSYaLzUmGi81JhovNSYaLzUmGi81JRmXsF3Jl7CHp\nzwkz8Z2CVMl+ocYNhJwJp8Go1DAq6SRCSuGCetfF2ZhK2C9kvxlEyIG4gJDX6jQYVbJfuBKYaWa/\n7eikvP3C2g1rOjrViUwl7BckfZwgzr07a0jefmFIvxFuv5CQomJrZb8gqWa/MAE4l5BXuphu2i8A\n0wFCZwl00X4BOJ6QND0/K2MQcKOkT5rZ+QXq5CSgKvYL/wTsBuybvV4ErsheToNQCfsFM3vHzBbV\nXoRhf7mZLe9umU58KmG/4DQHMtt8r5GH9Bth43c4NU0wdXbZGZe0OQgvzjazAzo7r2kSXtx+oflp\nGrHh9gtNT9OIze0Xmp9KPRt1GhsXm5MMF5uTjKa5ZiuDfrusY5efvZYk1oKv7ZEkTo3PTD4vYbS/\nLXSW92xOMlxsTjJcbE4yXGxOMlxsTjJcbE4yXGxOMlxsTjJcbE4yKmO/kGXfW5vXmLg1dHpC1ewX\nziHkqNZeSxLGdjqhas9GF7fJUXUaiMrYL2ScKWmlpMWS/rILv3MSUCX7BQgJzXsTkp5/KmlYwd85\nCaiE/ULGScBbZrZO0hTg74BPEIT9IZLOJWTxM3DEVhsV4pRH0Z6tlf0CQXwjCb3Va8C07HO37BfM\nbLqZ5Z2HZubeF7FfwMyW5cqoiXtwO+dNNbMDzOyALYf2L1BdJxaVsF+QtHd2PdkvO1RL+VvUlXKc\ncqmE/UIWZzlwmaTRwD8C84BZPSjTiUwl7Bey3x4PHAU8QxjSJ9rmnO7fgHR6g2BmM2kZls5q8/VJ\nbT5PafPbV4CD2ylzDjC+neOLaNM7mtmYzuqYnTebcEPgNChNM6nr9gvNT9OIDbdfaHqaRmxuv9D8\n+BIjJxkuNicZLjYnGS42JxlNc4NQBqvf3IrH/rlTd84orDomrc2p9Uo4n/2rYqd5z+Ykw8XmJMPF\n5iTDxeYkw8XmJMPF5iTDxeYkw8XmJKMy9gvZ76+StCLLG50Qs25Oz6mM/YKkk4GJwGcIW3j/TNLA\nFLGdYlTpcdXZwJVm9oyk+cC7+GVCQ1El+4VPA6MlvQG8REiE7uqOzE6JVMJ+QdJQoB8hu2o8cDtw\nmw+jjUVV7BcGZH+vM7OXJF1HsF84CPif/Il5+4W+A4YWbL4Tg6rYL6zM/q7I/r6V/R3e9sS8/ULv\n/gPafu2USCXsF8xsVVburtmhWqL0sq6U45RLVewXAH4OXCJpd+ASgh3DYz0s04lIJewXMr5NsF6Y\nS7hRONHM3u9hmU5EtDnbYQwYvqPteexFSWKtGlPdZeEvXH7JbDPrdH1900zquv1C89M0YsPtF5qe\nphGb2y80P/7s0EmGi81JhovNSYaLzUnGZj3PJumPdO/JxXDCE4oUNEOsncxs285O2qzF1l0kzSoy\niemxWuPDqJMMF5uTDBdb92hvOZXH6gS/ZnOS4T2bkwwXm5MMF5uTDBebkwwXm5OM/wc3UYgzVkg7\nuQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Test\n",
    "test_text = 'sorry hate you'\n",
    "tests = [np.asarray([word_id[n] for n in test_text.split()])]\n",
    "test_batch = Variable(torch.LongTensor(tests))\n",
    "\n",
    "# Predict\n",
    "predict, _ = model(test_batch)\n",
    "predict = predict.data.max(1, keepdim=True)[1]\n",
    "if predict[0][0] == 0:\n",
    "    print(test_text,\"is Bad Mean...\")\n",
    "else:\n",
    "    print(test_text,\"is Good Mean!!\")\n",
    "    \n",
    "fig = plt.figure(figsize=(6, 3)) # [batch_size, n_step]\n",
    "ax = fig.add_subplot(1, 1, 1)\n",
    "ax.matshow(attn, cmap='viridis')\n",
    "ax.set_xticklabels(['']+['first_word', 'second_word', 'third_word'], fontdict={'fontsize': 14}, rotation=90)\n",
    "ax.set_yticklabels(['']+['batch_1', 'batch_2', 'batch_3', 'batch_4', 'batch_5', 'batch_6'], fontdict={'fontsize': 14})\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
